{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2025-02-22T07:11:40.468932Z",
     "start_time": "2025-02-22T07:11:36.533710Z"
    }
   },
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Report interpreter and library versions so runs can be reproduced.\n",
    "print(sys.version_info)\n",
    "for module in (mpl, np, pd, sklearn, torch):\n",
    "    print(module.__name__, module.__version__)\n",
    "\n",
    "# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "print(device)\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sys.version_info(major=3, minor=12, micro=3, releaselevel='final', serial=0)\n",
      "matplotlib 3.10.0\n",
      "numpy 1.26.4\n",
      "pandas 2.2.3\n",
      "sklearn 1.6.1\n",
      "torch 2.6.0+cu118\n",
      "cuda:0\n"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## 加载数据",
   "id": "8be9f69ca4485c79"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-22T07:11:42.169376Z",
     "start_time": "2025-02-22T07:11:40.470386Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from torchvision import datasets\n",
    "from torchvision.transforms import ToTensor\n",
    "\n",
    "# Fashion-MNIST image-classification dataset; download=True fetches the files into ./data when missing.\n",
    "train_ds = datasets.FashionMNIST(root=\"data\", train=True, download=True, transform=ToTensor())\n",
    "test_ds = datasets.FashionMNIST(root=\"data\", train=False, download=True, transform=ToTensor())\n",
    "\n",
    "# torchvision does not provide a separate train/validation split for this dataset\n",
    "# (one could be built by hand with the torch.utils.data utilities),\n",
    "# so the test split is what backs val_loader here.\n",
    "# Wrap both datasets in DataLoaders; only the training data is shuffled.\n",
    "train_loader = torch.utils.data.DataLoader(dataset=train_ds, batch_size=16, shuffle=True)\n",
    "val_loader = torch.utils.data.DataLoader(dataset=test_ds, batch_size=16, shuffle=False)\n",
    "\n",
    "separator = '$' * 100\n",
    "first_image = train_ds[0][0]  # image tensor of the first training sample\n",
    "print(type(train_ds))\n",
    "print(first_image.shape)  # shape of the first sample\n",
    "print(train_ds.classes)  # class label names\n",
    "print(separator)\n",
    "print(train_ds.targets[:10])\n",
    "print(separator)\n",
    "plt.imshow(first_image.permute(1, 2, 0))  # show the first sample with the channel axis moved last\n",
    "train_ds"
   ],
   "id": "7a5fbd08f2673097",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'torchvision.datasets.mnist.FashionMNIST'>\n",
      "torch.Size([1, 28, 28])\n",
      "['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']\n",
      "$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n",
      "tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5])\n",
      "$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Dataset FashionMNIST\n",
       "    Number of datapoints: 60000\n",
       "    Root location: data\n",
       "    Split: Train\n",
       "    StandardTransform\n",
       "Transform: ToTensor()"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ],
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAIGZJREFUeJzt3Q1wFHW67/FnZvIOeSG85EXCu4AKxJUFRBRRuESs6wXleGX11oG9Fh5YsBZYVyt7VGR362QX67hcXRbOqbMLa5WCUlfkyHG5KyBhUdAF5KKryyFsFBDCm+Y9k8xk+ta/uYlEAXmaJP/JzPdT1TWZmX7optMzv+nu/zzxOY7jCAAAnczf2QsEAMAggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYkSBRJhKJyIkTJyQ9PV18Pp/t1QEAKJn+BjU1NZKfny9+v7/rBJAJn4KCAturAQC4SseOHZO+fft2nQAyRz7GrXK3JEii7dUBACiFJSS75M3W9/NOD6CVK1fKs88+KxUVFVJYWCgvvPCCjB079lvrWk67mfBJ8BFAANDl/P8Oo992GaVDBiG88sorsmTJElm6dKns37/fDaCioiI5ffp0RywOANAFdUgAPffcczJ37lz5/ve/L9dff72sXr1a0tLS5He/+11HLA4A0AW1ewA1NTXJvn37ZMqUKV8txO937+/evfsb8zc2Nkp1dXWbCQAQ+9o9gM6ePSvNzc2Sk5PT5nFz31wP+rqSkhLJzMxsnRgBBwDxwfoXUYuLi6Wqqqp1MsP2AACxr91HwfXq1UsCgYCcOnWqzePmfm5u7jfmT05OdicAQHxp9yOgpKQkGT16tGzbtq1NdwNzf/z48e29OABAF9Uh3wMyQ7Bnz54t3/3ud93v/qxYsULq6urcUXEAAHRYAD3wwANy5swZefrpp92BBzfeeKNs2bLlGwMTAADxy+eYrnFRxAzDNqPhJsl0OiEAQBcUdkKyQza5A8syMjKidxQcACA+EUAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsS7CwWiFI+n77GcaQzBHpmq2u+LBrqaVkZL++RaN3evoREdY0TapKY4/Owr3rVQfs4R0AAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAXNSIEL+AIBdY0TDqtr/Dder6755B+665fTIJ4k1o1V1yQ0RPTL+ePe6G4s6qVZqod9SHz+qN4OvgRdVPhM89IreFlwBAQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVtCMFLiKpotem5EeK8pS1zw0/k/qmnfODBIvPkvOVdc4qfrlJEwZr64Z+pvP1TXhT4+KJ6apZifsD14EevTwVCfNzfqS6mrV/I5zZduAIyAAgBUEEAAgNgLomWeeEZ/P12YaPnx4ey8GANDFdcg1oBtuuEG2bt361UI8nFcHAMS2DkkGEzi5ufqLmACA+NEh14AOHz4s+fn5MmjQIHnooYfk6NFLj0BpbGyU6urqNhMAIPa1ewCNGzdO1q5dK1u2bJFVq1ZJeXm53HbbbVJTU3PR+UtKSiQzM7N1KigoaO9VAgDEQwBNmzZN7r//fhk1apQUFRXJm2++KZWVlfLqq69edP7i4mKpqqpqnY4dO9beqwQAiEIdPjogKytLhg4dKmVlZRd9Pjk52Z0AAPGlw78HVFtbK0eOHJG8vLyOXhQAIJ4D6LHHHpPS0lL59NNP5d1335V7771XAoGAfO9732vvRQEAurB2PwV3/PhxN2zO
nTsnvXv3lltvvVX27Nnj/gwAQIcF0Pr169v7nwQ6TSQY7JTlNH2nVl3zd5l71TUp/pB4UeqPqGs+364fwdo8Sr8dPnsuXV0T+eAW8aLnR/rGnRkfnFTXnJ14jbrmzGh9o1QjZ4++psfWI6r5nUiTyNlvn49ecAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCAAQm3+QDrDC5/NW5+gbPNb+95vVNX9//Q51zZGQvqN836QvxIv78/fpi/6HvubXh25X19T9LVNd4+/mrXFnxc36z+ifT9f/npxQWF3TY7+3t2//7FPqmuqmQar5w6GgyKYrWBf1mgAA0A4IIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwgm7Y6BpdqqPYzU+8r665o/vH0hmuEW9doOucJHVNZXM3dc3S6/9DXXNmaLq6JuR4e6v7t8O3qGtqPXTrDoT1r4ub/+cH4sXM7D+ra5b/75Gq+cNO6Irm4wgIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKygGSk6l+OtOWY0O1zbR11zLqO7uqYinKWu6RmoFS/S/Q3qmgGJZ9U1Z5r1jUUDiRF1TZMTEC+W3fCGuiZ4XaK6JtHXrK65JeWEeHH/x3+vrukmf5OOwBEQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBM1LgKvVO1jf8TPGF1DVJvrC65kSoh3hxuGGYuuY/q/VNWe/K+Yu6JuShsWhAvDXB9dIkND/xS3VN0NE3MNXvQedNyNE3Fj0gHYMjIACAFQQQAKBrBNDOnTvlnnvukfz8fPH5fPL666+3ed5xHHn66aclLy9PUlNTZcqUKXL48OH2XGcAQDwGUF1dnRQWFsrKlSsv+vzy5cvl+eefl9WrV8t7770n3bp1k6KiIgkGg+2xvgCAeB2EMG3aNHe6GHP0s2LFCnnyySdl+vTp7mMvvvii5OTkuEdKs2bNuvo1BgDEhHa9BlReXi4VFRXuabcWmZmZMm7cONm9e/dFaxobG6W6urrNBACIfe0aQCZ8DHPEcyFzv+W5ryspKXFDqmUqKChoz1UCAEQp66PgiouLpaqqqnU6duyY7VUCAHS1AMrNzXVvT5061eZxc7/lua9LTk6WjIyMNhMAIPa1awANHDjQDZpt27a1Pmau6ZjRcOPHj2/PRQEA4m0UXG1trZSVlbUZeHDgwAHJzs6Wfv36yaJFi+TnP/+5XHvttW4gPfXUU+53hmbMmNHe6w4AiKcA2rt3r9xxxx2t95csWeLezp49W9auXSuPP/64+12hRx55RCorK+XWW2+VLVu2SEpKSvuuOQCgS/M55ss7UcScsjOj4SbJdEnw6Rv0Icr5fPqSgL75pBPWN+40Aj30zTtn7f5Qvxyf/mV3JpyurskK1IsXpZX6ZqR/OXfx67yX89Nh/66u2V8/QF2Tn6RvEOp1+33a1Etdc23yxUcJX84fviwULwpSvlDX/HHRRNX84XBQdu1Y5g4su9x1feuj4AAA8YkAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAICu8ecYgKviofm6LyGh07phH3v4OnXNnWlvqGveDV6jrumdUKOuCTn6TuJGXnKVuiY9J6iuqWxOU9dkJ9Sqa2qaU8WLNH9jp/yebko6q65ZvPUm8SJ9xDl1TUai7lglcoXHNhwBAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVNCNFp/IlJqlrIkF9k0uven3YpK4525yorsny16trknzN
6pomj81Ib8kuV9ec8dDwc3/DQHVNeqBBXdPbr28QahQk6ht3fhgsUNe8WTdEXfPwf90qXqz71/+irkna8q5qfr8TurL51GsCAEA7IIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAV8d2M1OfzVpagbz7pC3jIer++JhJs1C8nom9y6ZUT0jf77Ez/619+ra45Fs5S11SE9DVZAX0D02bxto/vachU16T4r6wB5YV6J1Sra6oj+qanXtVEUtQ1IQ8NYFM8bLsneh4WL16rmiLRgiMgAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALAiZpqR+hL0/xUnHO60hpqOvtdgTGqYPlZdc2yGvlnqQ995X7yoCKeraz6oH6CuyQw0qGu6+fWNZoOOvnGucaKpR6c01MxOqFXX9PHQwLTZ8fZZ+/OQfjt4keWh0ezxsH7bGTX/rUZdk/WidAiOgAAAVhBAAICuEUA7d+6Ue+65R/Lz88Xn88nrr7/e5vk5c+a4j1843XXXXe25zgCAeAyguro6KSwslJUrV15yHhM4J0+ebJ3WrVt3tesJAIgx6iv306ZNc6fLSU5Oltzc3KtZLwBAjOuQa0A7duyQPn36yLBhw2T+/Ply7ty5S87b2Ngo1dXVbSYAQOxr9wAyp99efPFF2bZtm/zyl7+U0tJS94ipufniQ2lLSkokMzOzdSooKGjvVQIAxMP3gGbNmtX688iRI2XUqFEyePBg96ho8uTJ35i/uLhYlixZ0nrfHAERQgAQ+zp8GPagQYOkV69eUlZWdsnrRRkZGW0mAEDs6/AAOn78uHsNKC8vr6MXBQCI5VNwtbW1bY5mysvL5cCBA5Kdne1Oy5Ytk5kzZ7qj4I4cOSKPP/64DBkyRIqKitp73QEA8RRAe/fulTvuuKP1fsv1m9mzZ8uqVavk4MGD8vvf/14qKyvdL6tOnTpVfvazn7mn2gAAaOFzHMeRKGIGIZjRcJNkuiT4vDVSjEYJefrvRYUG5qhrvrguTV1Tn+sTL268+xN1zZycXeqaM83664KJPm+NZmuaU9U1uYmV6prtVdera7onNHZK01PjptRP1TWVEf2+l5/wpbrmibK/U9fkpOkbcBr/1v9NdU3IiahrDoX0H9DT/fqmyMaf6oeoazZe31s1f9gJyQ7ZJFVVVZe9rk8vOACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAMTGn+S2pXHaGHVNn3/8m6dl3ZhxXF1zfaq+C3Qwou8GnuIPqWs+brhGvKiPJKlrDjfpu4JXhfVdlgM+fUdi43RTurrmn8unqGu2jV2trnnyxF3qGn+qt2b355q7q2tmdq/2sCT9Pv4P/XaqawYlnRYvNtfp/5DmiVAPdU1OYpW6ZkDiGfHivvT/VNdsFF037CvFERAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWBG1zUh9CQni81356o37pz+rlzE5/S/iRb2T3CmNRb00NfQiM6HeU11jSL/7nA5lSGcYmlzhqe7ejAPqmp2/HqeuuTX4qLrmyJ1r1DXbGgLixZmw/vc0q/xOdc3+owXqmpsHlKtrRqZ/Ll54aYSbHgiqaxJ9YXVNXUT/PmTsCeobzXYUjoAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwIqobUZ6cv5oCSSnXPH8z2S+oF7Gy1/cLF4UpHyhrumfdFZdU5j6mXSGdL++eaIxLEPfQHFzXV91zY7K4eqavMRK8eJP9YPVNeufeVZdM2fxj9Q149+cp66pHuDtM2a4m6OuySg8p6558jv/
oa5J8jWrayqb9U1FjezkOnVNVsBbc9/OaIpspPsb1DWBYUNU8zvNjSKHv30+joAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwIqobUaadjoigaTIFc+/ufpG9TIGpZ4RL86G0tU1/6d2pLqmb+qX6prMgL7R4JDkCvHiQDBLXbPlzA3qmvzUanXNqVCmeHEu1E1dUx/RN4X87a+eU9f886kp6pp7s/eLF4VJ+sailRH959mPm3LVNTWRK29S3CLoJIoXVR6amKZ7eA2GHP1bccC58vfHC2X59c1Sq0f2VM0fDgVpRgoAiF4EEAAg+gOopKRExowZI+np6dKnTx+ZMWOGHDp0qM08wWBQFixYID179pTu3bvLzJkz5dSpU+293gCAeAqg0tJSN1z27Nkjb731loRCIZk6darU1X31R5sWL14sb7zxhmzYsMGd/8SJE3Lfffd1xLoDALow1ZWvLVu2tLm/du1a90ho3759MnHiRKmqqpLf/va38vLLL8udd97pzrNmzRq57rrr3NC6+WZvf4EUABB7ruoakAkcIzs72701QWSOiqZM+Wq0zvDhw6Vfv36ye/fui/4bjY2NUl1d3WYCAMQ+zwEUiURk0aJFMmHCBBkxYoT7WEVFhSQlJUlWVtvhuTk5Oe5zl7qulJmZ2ToVFBR4XSUAQDwEkLkW9NFHH8n69euvagWKi4vdI6mW6dixY1f17wEAYviLqAsXLpTNmzfLzp07pW/fvq2P5+bmSlNTk1RWVrY5CjKj4MxzF5OcnOxOAID4ojoCchzHDZ+NGzfK9u3bZeDAgW2eHz16tCQmJsq2bdtaHzPDtI8ePSrjx49vv7UGAMTXEZA57WZGuG3atMn9LlDLdR1z7SY1NdW9ffjhh2XJkiXuwISMjAx59NFH3fBhBBwAwHMArVq1yr2dNGlSm8fNUOs5c+a4P//qV78Sv9/vfgHVjHArKiqS3/zmN5rFAADigM8x59WiiBmGbY6kJt76lCQkXHnTwTEr9qmX9VF1vniRk1KjrhnV/bi65lC9vlHjiYYMdU1aQki8SA3o68KOftxLn2T99u6XrG+maaT79Y0kk3zN6ppmD+N/bkg6oa45Gu4hXlSE9Y1mP67Xv556JOgbY37o4XVbH04SLxqb9ZfJg2F9TWZyUF0zJvsz8cIv+rf8l//9dtX8kWBQ/vbzf3QHlpkzYZdeFwAALCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAKDr/EXUzuDfdVD8vsQrnn/DHyeol/HU9A3iRWnlcHXN5oqR6prqJv1fiu2dVqeuyUjUd5s2shP1y8r00P04xRdW13wZ7iZeNPqvfJ9r0Sw+dU1FY6a65p3IteqaUCQgXjR6qPPSHf2Lpl7qmvzUKnVNTfjKO+tf6NOabHXN2aru6ppgmv6teFfzYPHirty/qGtST+v28ebGK5ufIyAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsMLnOI4jUaS6uloyMzNlkkyXBEUzUi+qHrrZU92gHxxS14zNKlfX7K/up6456qF5Yiji7XNIoj+irklLbFLXpHhocpkUaBYv/KJ/OUQ8NCPtFtBvh24JjeqajISgeJEe0Nf5ffr9wYuAh9/R+1UDpLOke/g9hR39a3B85hHx4nflt6hrMu8uU80fdkKyQzZJVVWVZGRkXHI+joAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwIrobUbqv0/XjDTirflkZ6mbOU5dM+4nf9bXpOsbFA5POiVeJIq++WSKh4aV3fz6Zp9Bj7u1l09kuxoK1DXNHpa0/cvr1DUhD00ujVP1l24geSmJ
HhvAakUc/f7QEPbW2LiqIUVdE/Dr973gjl7qmp4f65v0Gslv6t9XtGhGCgCIagQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwInqbkcp0XTNSeOYbM9JTXUNuqrom+Vyjuqamv345GUfqxAt/Y1hdE/m/n3haFhCraEYKAIhqBBAAIPoDqKSkRMaMGSPp6enSp08fmTFjhhw6dKjNPJMmTRKfz9dmmjdvXnuvNwAgngKotLRUFixYIHv27JG33npLQqGQTJ06Verq2p5vnzt3rpw8ebJ1Wr58eXuvNwCgi0vQzLxly5Y299euXeseCe3bt08mTpzY+nhaWprk5ua231oCAGLOVV0DMiMcjOzs7DaPv/TSS9KrVy8ZMWKEFBcXS319/SX/jcbGRnfk24UTACD2qY6ALhSJRGTRokUyYcIEN2haPPjgg9K/f3/Jz8+XgwcPyhNPPOFeJ3rttdcueV1p2bJlXlcDABBv3wOaP3++/OEPf5Bdu3ZJ3759Lznf9u3bZfLkyVJWViaDBw++6BGQmVqYI6CCggK+B9SJ+B7QV/geENB53wPydAS0cOFC2bx5s+zcufOy4WOMGzfOvb1UACUnJ7sTACC+qALIHCw9+uijsnHjRtmxY4cMHDjwW2sOHDjg3ubl5XlfSwBAfAeQGYL98ssvy6ZNm9zvAlVUVLiPm9Y5qampcuTIEff5u+++W3r27OleA1q8eLE7Qm7UqFEd9X8AAMR6AK1atar1y6YXWrNmjcyZM0eSkpJk69atsmLFCve7QeZazsyZM+XJJ59s37UGAMTfKbjLMYFjvqwKAECHDcNG7HD+/KGnuhTpHBnvdtKCzIi2zlsUEPdoRgoAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGBFgkQZx3Hc27CERM7/CADoQtz37wvez7tMANXU1Li3u+RN26sCALjK9/PMzMxLPu9zvi2iOlkkEpETJ05Ienq6+Hy+Ns9VV1dLQUGBHDt2TDIyMiResR3OYzucx3Y4j+0QPdvBxIoJn/z8fPH7/V3nCMisbN++fS87j9mo8byDtWA7nMd2OI/tcB7bITq2w+WOfFowCAEAYAUBBACwoksFUHJysixdutS9jWdsh/PYDuexHc5jO3S97RB1gxAAAPGhSx0BAQBiBwEEALCCAAIAWEEAAQCs6DIBtHLlShkwYICkpKTIuHHj5P3335d488wzz7jdIS6chg8fLrFu586dcs8997jfqjb/59dff73N82YczdNPPy15eXmSmpoqU6ZMkcOHD0u8bYc5c+Z8Y/+46667JJaUlJTImDFj3E4pffr0kRkzZsihQ4fazBMMBmXBggXSs2dP6d69u8ycOVNOnTol8bYdJk2a9I39Yd68eRJNukQAvfLKK7JkyRJ3aOH+/fulsLBQioqK5PTp0xJvbrjhBjl58mTrtGvXLol1dXV17u/cfAi5mOXLl8vzzz8vq1evlvfee0+6devm7h/mjSietoNhAufC/WPdunUSS0pLS91w2bNnj7z11lsSCoVk6tSp7rZpsXjxYnnjjTdkw4YN7vymtdd9990n8bYdjLlz57bZH8xrJao4XcDYsWOdBQsWtN5vbm528vPznZKSEieeLF261CksLHTimdllN27c2Ho/Eok4ubm5zrPPPtv6WGVlpZOcnOysW7fOiZftYMyePduZPn26E09Onz7tbovS0tLW331iYqKzYcOG1nk++eQTd57du3c78bIdjNtvv9354Q9/6ESzqD8Campqkn379rmnVS7sF2fu7969W+KNObVkTsEMGjRIHnroITl69KjEs/LycqmoqGizf5geVOY0bTzuHzt27HBP
yQwbNkzmz58v586dk1hWVVXl3mZnZ7u35r3CHA1cuD+Y09T9+vWL6f2h6mvbocVLL70kvXr1khEjRkhxcbHU19dLNIm6ZqRfd/bsWWlubpacnJw2j5v7f/3rXyWemDfVtWvXum8u5nB62bJlctttt8lHH33knguORyZ8jIvtHy3PxQtz+s2caho4cKAcOXJEfvKTn8i0adPcN95AICCxxnTOX7RokUyYMMF9gzXM7zwpKUmysrLiZn+IXGQ7GA8++KD079/f/cB68OBBeeKJJ9zrRK+99ppEi6gPIHzFvJm0GDVqlBtIZgd79dVX5eGHH7a6brBv1qxZrT+PHDnS3UcGDx7sHhVNnjxZYo25BmI+fMXDdVAv2+GRRx5psz+YQTpmPzAfTsx+EQ2i/hScOXw0n96+PorF3M/NzZV4Zj7lDR06VMrKyiRetewD7B/fZE7TmtdPLO4fCxculM2bN8vbb7/d5s+3mN+5OW1fWVkZF/vDwktsh4sxH1iNaNofoj6AzOH06NGjZdu2bW0OOc398ePHSzyrra11P82YTzbxypxuMm8sF+4f5g9ymdFw8b5/HD9+3L0GFEv7hxl/Yd50N27cKNu3b3d//xcy7xWJiYlt9gdz2slcK42l/cH5lu1wMQcOHHBvo2p/cLqA9evXu6Oa1q5d63z88cfOI4884mRlZTkVFRVOPPnRj37k7NixwykvL3feeecdZ8qUKU6vXr3cETCxrKamxvnggw/cyeyyzz33nPvzZ5995j7/i1/8wt0fNm3a5Bw8eNAdCTZw4ECnoaHBiZftYJ577LHH3JFeZv/YunWrc9NNNznXXnutEwwGnVgxf/58JzMz030dnDx5snWqr69vnWfevHlOv379nO3btzt79+51xo8f706xZP63bIeysjLnpz/9qfv/N/uDeW0MGjTImThxohNNukQAGS+88IK7UyUlJbnDsvfs2ePEmwceeMDJy8tzt8E111zj3jc7Wqx7++233Tfcr09m2HHLUOynnnrKycnJcT+oTJ482Tl06JATT9vBvPFMnTrV6d27tzsMuX///s7cuXNj7kPaxf7/ZlqzZk3rPOaDxw9+8AOnR48eTlpamnPvvfe6b87xtB2OHj3qhk12drb7mhgyZIjz4x//2KmqqnKiCX+OAQBgRdRfAwIAxCYCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAiA3/D4e0yFBlkdT0AAAAAElFTkSuQmCC"
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "execution_count": 2
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-22T07:25:50.040654Z",
     "start_time": "2025-02-22T07:25:45.020643Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Peek at a single batch. next(iter(...)) fetches only the first batch; a for-loop\n",
    "# without a break would keep iterating over every remaining batch of the loader.\n",
    "img, labels = next(iter(train_loader))\n",
    "print(img)\n",
    "print(labels)"
   ],
   "id": "8af9cf7157763f3d",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "          ...,\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]],\n",
      "\n",
      "\n",
      "        [[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "          ...,\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]],\n",
      "\n",
      "\n",
      "        [[[0.0000, 0.0000, 0.0000,  ..., 0.0118, 0.0000, 0.0000],\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0314, 0.0000, 0.0000],\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0039, 0.0000],\n",
      "          ...,\n",
      "          [0.2510, 0.7373, 0.7137,  ..., 0.0000, 0.0392, 0.0000],\n",
      "          [0.0000, 0.4157, 0.6039,  ..., 0.0000, 0.0353, 0.0000],\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0235, 0.0000]]],\n",
      "\n",
      "\n",
      "        ...,\n",
      "\n",
      "\n",
      "        [[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.3176, 0.0000, 0.0118],\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.2353, 0.0000, 0.0157],\n",
      "          ...,\n",
      "          [0.0000, 0.6667, 0.9020,  ..., 0.7922, 0.7569, 0.5765],\n",
      "          [0.0000, 0.0118, 0.3961,  ..., 0.7647, 0.7843, 0.5137],\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0118, 0.0078, 0.0000]]],\n",
      "\n",
      "\n",
      "        [[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "          ...,\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]],\n",
      "\n",
      "\n",
      "        [[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "          ...,\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]]])\n",
      "tensor([9, 4, 5, 1, 1, 1, 0, 9, 2, 7, 5, 4, 6, 9, 4, 2])\n"
     ]
    }
   ],
   "execution_count": 8
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-22T07:11:54.459786Z",
     "start_time": "2025-02-22T07:11:47.340372Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from torchvision.transforms import Normalize\n",
    "\n",
    "# 遍历train_ds得到每张图片，计算每个通道的均值和方差\n",
    "def cal_mean_std(ds):\n",
    "    mean = 0.\n",
    "    std = 0.\n",
    "    for img, _ in ds: # 遍历每张图片,img.shape=[1,28,28]\n",
    "        mean += img.mean(dim=(1, 2))\n",
    "        std += img.std(dim=(1, 2))\n",
    "    mean /= len(ds)\n",
    "    std /= len(ds)\n",
    "    return mean, std\n",
    "\n",
    "\n",
    "# Compute the training-split statistics once and pass them straight to Normalize,\n",
    "# instead of copying rounded numbers from the printed output by hand.\n",
    "train_mean, train_std = cal_mean_std(train_ds)\n",
    "print((train_mean, train_std))\n",
    "transforms = nn.Sequential(\n",
    "    Normalize(train_mean.tolist(), train_std.tolist())  # per-channel mean / std of train_ds\n",
    ")\n"
   ],
   "id": "6028d0586d07f9f9",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(tensor([0.2860]), tensor([0.3205]))\n"
     ]
    }
   ],
   "execution_count": 4
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## 定义模型\n",
    "\n",
    "这里我们没有用`nn.Linear`的默认初始化，而是采用了xavier均匀分布去初始化全连接层的权重\n",
    "\n",
    "xavier初始化出自论文 《Understanding the difficulty of training deep feedforward neural networks》，适用于使用`tanh`和`sigmoid`激活函数的方法。当然，我们这里的模型采用的是`relu`激活函数，采用He初始化（何恺明初始化）会更加合适。感兴趣的同学可以自己动手修改并比对效果。\n",
    "\n",
    "|神经网络层数|初始化方式|early stop at epoch| val_loss | val_acc|\n",
    "|-|-|-|-|-|\n",
    "|20|默认| | | |\n",
    "|20|xavier_uniform| | | |\n",
    "|20|he_uniform| | | |\n",
    "|...| | | | |\n",
    "\n",
    "He初始化出自论文 《Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification》"
   ],
   "id": "f4895a9a040d93d8"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-22T07:11:54.477117Z",
     "start_time": "2025-02-22T07:11:54.460947Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class NeuralNetwork(nn.Module):\n",
    "    def __init__(self, layers_num=2):\n",
    "        super().__init__()\n",
    "        self.transforms = transforms # 预处理层，标准化\n",
    "        self.flatten = nn.Flatten()\n",
    "        # 多加几层\n",
    "        self.linear_relu_stack = nn.Sequential(\n",
    "            nn.Linear(28 * 28, 100),\n",
    "            nn.ReLU(),\n",
    "        )\n",
    "        # 加19层\n",
    "        for i in range(1, layers_num):\n",
    "            self.linear_relu_stack.add_module(f\"Linear_{i}\", nn.Linear(100, 100))\n",
    "            self.linear_relu_stack.add_module(f\"relu\", nn.ReLU())\n",
    "        # 输出层\n",
    "        self.linear_relu_stack.add_module(\"Output Layer\", nn.Linear(100, 10))\n",
    "        \n",
    "        # 初始化权重\n",
    "        self.init_weights()\n",
    "        \n",
    "    def init_weights(self):\n",
    "        \"\"\"使用 xavier 均匀分布来初始化全连接层的权重 W\"\"\"\n",
    "        # print('''初始化权重''')\n",
    "        for m in self.modules():\n",
    "            # print(m)\n",
    "            # print('-'*50)\n",
    "            if isinstance(m, nn.Linear):#判断m是否为全连接层\n",
    "                # https://pytorch.org/docs/stable/nn.init.html\n",
    "                nn.init.xavier_uniform_(m.weight) # xavier 均匀分布初始化权重\n",
    "                nn.init.zeros_(m.bias) # 全零初始化偏置项\n",
    "        # print('''初始化权重完成''')\n",
    "    def forward(self, x):\n",
    "        # x.shape [batch size, 1, 28, 28]\n",
    "        x = self.transforms(x) #标准化\n",
    "        x = self.flatten(x)  \n",
    "        # 展平后 x.shape [batch size, 28 * 28]\n",
    "        logits = self.linear_relu_stack(x)\n",
    "        # logits.shape [batch size, 10]\n",
    "        return logits\n",
    "# Build a 20-hidden-layer model and list the size and shape of every parameter tensor.\n",
    "# named_parameters() yields each layer's weight followed by its bias, so idx // 2 is the layer number.\n",
    "total = 0\n",
    "for idx, (key, value) in enumerate(NeuralNetwork(20).named_parameters()):\n",
    "    param_count = np.prod(value.shape)  # number of elements in this tensor\n",
    "    print(f\"Linear_{idx // 2:>02}\\tparamerters num: {param_count}\")\n",
    "    print(f\"Linear_{idx // 2:>02}\\tshape: {value.shape}\")\n",
    "    total += param_count\n",
    "total"
   ],
   "id": "11ebd87b0b88b1d0",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Linear_00\tparamerters num: 78400\n",
      "Linear_00\tshape: torch.Size([100, 784])\n",
      "Linear_00\tparamerters num: 100\n",
      "Linear_00\tshape: torch.Size([100])\n",
      "Linear_01\tparamerters num: 10000\n",
      "Linear_01\tshape: torch.Size([100, 100])\n",
      "Linear_01\tparamerters num: 100\n",
      "Linear_01\tshape: torch.Size([100])\n",
      "Linear_02\tparamerters num: 10000\n",
      "Linear_02\tshape: torch.Size([100, 100])\n",
      "Linear_02\tparamerters num: 100\n",
      "Linear_02\tshape: torch.Size([100])\n",
      "Linear_03\tparamerters num: 10000\n",
      "Linear_03\tshape: torch.Size([100, 100])\n",
      "Linear_03\tparamerters num: 100\n",
      "Linear_03\tshape: torch.Size([100])\n",
      "Linear_04\tparamerters num: 10000\n",
      "Linear_04\tshape: torch.Size([100, 100])\n",
      "Linear_04\tparamerters num: 100\n",
      "Linear_04\tshape: torch.Size([100])\n",
      "Linear_05\tparamerters num: 10000\n",
      "Linear_05\tshape: torch.Size([100, 100])\n",
      "Linear_05\tparamerters num: 100\n",
      "Linear_05\tshape: torch.Size([100])\n",
      "Linear_06\tparamerters num: 10000\n",
      "Linear_06\tshape: torch.Size([100, 100])\n",
      "Linear_06\tparamerters num: 100\n",
      "Linear_06\tshape: torch.Size([100])\n",
      "Linear_07\tparamerters num: 10000\n",
      "Linear_07\tshape: torch.Size([100, 100])\n",
      "Linear_07\tparamerters num: 100\n",
      "Linear_07\tshape: torch.Size([100])\n",
      "Linear_08\tparamerters num: 10000\n",
      "Linear_08\tshape: torch.Size([100, 100])\n",
      "Linear_08\tparamerters num: 100\n",
      "Linear_08\tshape: torch.Size([100])\n",
      "Linear_09\tparamerters num: 10000\n",
      "Linear_09\tshape: torch.Size([100, 100])\n",
      "Linear_09\tparamerters num: 100\n",
      "Linear_09\tshape: torch.Size([100])\n",
      "Linear_10\tparamerters num: 10000\n",
      "Linear_10\tshape: torch.Size([100, 100])\n",
      "Linear_10\tparamerters num: 100\n",
      "Linear_10\tshape: torch.Size([100])\n",
      "Linear_11\tparamerters num: 10000\n",
      "Linear_11\tshape: torch.Size([100, 100])\n",
      "Linear_11\tparamerters num: 100\n",
      "Linear_11\tshape: torch.Size([100])\n",
      "Linear_12\tparamerters num: 10000\n",
      "Linear_12\tshape: torch.Size([100, 100])\n",
      "Linear_12\tparamerters num: 100\n",
      "Linear_12\tshape: torch.Size([100])\n",
      "Linear_13\tparamerters num: 10000\n",
      "Linear_13\tshape: torch.Size([100, 100])\n",
      "Linear_13\tparamerters num: 100\n",
      "Linear_13\tshape: torch.Size([100])\n",
      "Linear_14\tparamerters num: 10000\n",
      "Linear_14\tshape: torch.Size([100, 100])\n",
      "Linear_14\tparamerters num: 100\n",
      "Linear_14\tshape: torch.Size([100])\n",
      "Linear_15\tparamerters num: 10000\n",
      "Linear_15\tshape: torch.Size([100, 100])\n",
      "Linear_15\tparamerters num: 100\n",
      "Linear_15\tshape: torch.Size([100])\n",
      "Linear_16\tparamerters num: 10000\n",
      "Linear_16\tshape: torch.Size([100, 100])\n",
      "Linear_16\tparamerters num: 100\n",
      "Linear_16\tshape: torch.Size([100])\n",
      "Linear_17\tparamerters num: 10000\n",
      "Linear_17\tshape: torch.Size([100, 100])\n",
      "Linear_17\tparamerters num: 100\n",
      "Linear_17\tshape: torch.Size([100])\n",
      "Linear_18\tparamerters num: 10000\n",
      "Linear_18\tshape: torch.Size([100, 100])\n",
      "Linear_18\tparamerters num: 100\n",
      "Linear_18\tshape: torch.Size([100])\n",
      "Linear_19\tparamerters num: 10000\n",
      "Linear_19\tshape: torch.Size([100, 100])\n",
      "Linear_19\tparamerters num: 100\n",
      "Linear_19\tshape: torch.Size([100])\n",
      "Linear_20\tparamerters num: 1000\n",
      "Linear_20\tshape: torch.Size([10, 100])\n",
      "Linear_20\tparamerters num: 10\n",
      "Linear_20\tshape: torch.Size([10])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "271410"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 5
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-22T07:32:18.310324Z",
     "start_time": "2025-02-22T07:32:18.303534Z"
    }
   },
   "cell_type": "code",
   "source": [
    "w = torch.empty(3, 5)\n",
    "nn.init.eye_(w)"
   ],
   "id": "54e0db95c46a2ba8",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 0., 0., 0., 0.],\n",
       "        [0., 1., 0., 0., 0.],\n",
       "        [0., 0., 1., 0., 0.]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 9
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## 训练",
   "id": "95b97a11829fe61e"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-22T07:32:29.876284Z",
     "start_time": "2025-02-22T07:32:29.775615Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluating(model, dataloader, loss_fct):\n",
    "    loss_list = []\n",
    "    pred_list = []\n",
    "    label_list = []\n",
    "    for datas, labels in dataloader:\n",
    "        datas = datas.to(device)\n",
    "        labels = labels.to(device)\n",
    "        # 前向计算\n",
    "        logits = model(datas)\n",
    "        loss = loss_fct(logits, labels)         # 验证集损失\n",
    "        loss_list.append(loss.item())\n",
    "        \n",
    "        preds = logits.argmax(axis=-1)    # 验证集预测\n",
    "        pred_list.extend(preds.cpu().numpy().tolist())\n",
    "        label_list.extend(labels.cpu().numpy().tolist())\n",
    "        \n",
    "    acc = accuracy_score(label_list, pred_list)\n",
    "    return np.mean(loss_list), acc\n"
   ],
   "id": "935fc89e0fd795a3",
   "outputs": [],
   "execution_count": 10
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-22T08:24:44.198841Z",
     "start_time": "2025-02-22T08:24:43.853542Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "\n",
    "class TensorBoardCallback:\n",
    "    def __init__(self, log_dir, flush_secs=10):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            log_dir (str): dir to write log.\n",
    "            flush_secs (int, optional): write to dsk each flush_secs seconds. Defaults to 10.\n",
    "        \"\"\"\n",
    "        self.writer = SummaryWriter(log_dir=log_dir, flush_secs=flush_secs)\n",
    "\n",
    "    def draw_model(self, model, input_shape):\n",
    "        self.writer.add_graph(model, input_to_model=torch.randn(input_shape))\n",
    "        \n",
    "    def add_loss_scalars(self, step, loss, val_loss):\n",
    "        self.writer.add_scalars(\n",
    "            main_tag=\"training/loss\", \n",
    "            tag_scalar_dict={\"loss\": loss, \"val_loss\": val_loss},\n",
    "            global_step=step,\n",
    "            )\n",
    "        \n",
    "    def add_acc_scalars(self, step, acc, val_acc):\n",
    "        self.writer.add_scalars(\n",
    "            main_tag=\"training/accuracy\",\n",
    "            tag_scalar_dict={\"accuracy\": acc, \"val_accuracy\": val_acc},\n",
    "            global_step=step,\n",
    "        )\n",
    "        \n",
    "    def add_lr_scalars(self, step, learning_rate):\n",
    "        self.writer.add_scalars(\n",
    "            main_tag=\"training/learning_rate\",\n",
    "            tag_scalar_dict={\"learning_rate\": learning_rate},\n",
    "            global_step=step,\n",
    "            \n",
    "        )\n",
    "    \n",
    "    def __call__(self, step, **kwargs):\n",
    "        # add loss\n",
    "        loss = kwargs.pop(\"loss\", None)\n",
    "        val_loss = kwargs.pop(\"val_loss\", None)\n",
    "        if loss is not None and val_loss is not None:\n",
    "            self.add_loss_scalars(step, loss, val_loss)\n",
    "        # add acc\n",
    "        acc = kwargs.pop(\"acc\", None)\n",
    "        val_acc = kwargs.pop(\"val_acc\", None)\n",
    "        if acc is not None and val_acc is not None:\n",
    "            self.add_acc_scalars(step, acc, val_acc)\n",
    "        # add lr\n",
    "        learning_rate = kwargs.pop(\"lr\", None)\n",
    "        if learning_rate is not None:\n",
    "            self.add_lr_scalars(step, learning_rate)\n"
   ],
   "id": "7cac1e8b8e6d0cc6",
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'tensorboard'",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mModuleNotFoundError\u001B[0m                       Traceback (most recent call last)",
      "Cell \u001B[1;32mIn[11], line 1\u001B[0m\n\u001B[1;32m----> 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mtorch\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mtensorboard\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m SummaryWriter\n\u001B[0;32m      4\u001B[0m \u001B[38;5;28;01mclass\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mTensorBoardCallback\u001B[39;00m:\n\u001B[0;32m      5\u001B[0m     \u001B[38;5;28;01mdef\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21m__init__\u001B[39m(\u001B[38;5;28mself\u001B[39m, log_dir, flush_secs\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m10\u001B[39m):\n",
      "File \u001B[1;32m~\\venv\\Lib\\site-packages\\torch\\utils\\tensorboard\\__init__.py:1\u001B[0m\n\u001B[1;32m----> 1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mtensorboard\u001B[39;00m\n\u001B[0;32m      2\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mtorch\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_vendor\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mpackaging\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mversion\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m Version\n\u001B[0;32m      4\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28mhasattr\u001B[39m(tensorboard, \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m__version__\u001B[39m\u001B[38;5;124m\"\u001B[39m) \u001B[38;5;129;01mor\u001B[39;00m Version(\n\u001B[0;32m      5\u001B[0m     tensorboard\u001B[38;5;241m.\u001B[39m__version__\n\u001B[0;32m      6\u001B[0m ) \u001B[38;5;241m<\u001B[39m Version(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m1.15\u001B[39m\u001B[38;5;124m\"\u001B[39m):\n",
      "\u001B[1;31mModuleNotFoundError\u001B[0m: No module named 'tensorboard'"
     ]
    }
   ],
   "execution_count": 11
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-22T08:24:59.631299Z",
     "start_time": "2025-02-22T08:24:59.625870Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class SaveCheckpointsCallback:\n",
    "    \"\"\"Write checkpoints with ``torch.save`` every ``save_step`` training steps.\"\"\"\n",
    "\n",
    "    def __init__(self, save_dir, save_step=5000, save_best_only=True):\n",
    "        \"\"\"\n",
    "        Set up the checkpoint directory and the saving policy.\n",
    "\n",
    "        Args:\n",
    "            save_dir (str): directory the checkpoint files are written to. It is\n",
    "                created, including missing parent directories, if it does not exist.\n",
    "            save_step (int, optional): a checkpoint is only considered on steps that\n",
    "                are a multiple of this value. Defaults to 5000.\n",
    "            save_best_only (bool, optional): if True, keep a single ``best.ckpt`` that\n",
    "                is overwritten whenever the metric ties or beats the best value seen\n",
    "                so far; if False, write ``{step}.ckpt`` at every saving step.\n",
    "                Defaults to True.\n",
    "        \"\"\"\n",
    "        self.save_dir = save_dir\n",
    "        self.save_step = save_step\n",
    "        self.save_best_only = save_best_only\n",
    "        # Start below every real metric so the first evaluated checkpoint is always\n",
    "        # saved, even when the metric is negative (e.g. a negated loss).\n",
    "        self.best_metrics = float(\"-inf\")\n",
    "\n",
    "        # makedirs (unlike mkdir) copes with nested paths such as \"checkpoints/run1\",\n",
    "        # and exist_ok=True removes the exists()/mkdir() race.\n",
    "        os.makedirs(self.save_dir, exist_ok=True)\n",
    "\n",
    "    def __call__(self, step, state_dict, metric=None):\n",
    "        \"\"\"\n",
    "        Save ``state_dict`` when ``step`` is a saving step and the policy allows it.\n",
    "\n",
    "        Args:\n",
    "            step (int): current training step; nothing happens unless it is a\n",
    "                multiple of ``save_step``.\n",
    "            state_dict: object handed to ``torch.save`` (presumably a model state dict).\n",
    "            metric (optional): score compared with ``>=`` against the best value so\n",
    "                far, so higher is treated as better. Required when\n",
    "                ``save_best_only`` is True.\n",
    "\n",
    "        Raises:\n",
    "            AssertionError: if ``save_best_only`` is True and ``metric`` is None.\n",
    "        \"\"\"\n",
    "        if step % self.save_step > 0:\n",
    "            return\n",
    "\n",
    "        if self.save_best_only:\n",
    "            assert metric is not None, \"metric is required when save_best_only=True\"\n",
    "            if metric >= self.best_metrics:\n",
    "                # Overwrite the single best checkpoint, then remember its score.\n",
    "                torch.save(state_dict, os.path.join(self.save_dir, \"best.ckpt\"))\n",
    "                self.best_metrics = metric\n",
    "        else:\n",
    "            torch.save(state_dict, os.path.join(self.save_dir, f\"{step}.ckpt\"))"
   ],
   "id": "2092e7e3e53c352d",
   "outputs": [],
   "execution_count": 12
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-22T15:55:16.722828Z",
     "start_time": "2025-02-22T15:55:16.537072Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# Fully connected layer mapping 5 input features to 3 output features\n",
    "layer = nn.Linear(in_features=5, out_features=3)\n",
    "\n",
    "# Feed a single unbatched sample: a 1-D tensor of shape (5,)\n",
    "x = torch.randn(5)\n",
    "output = layer(x)  # resulting shape: (3,)\n",
    "\n",
    "print(\"输出形状:\", output.shape)  # torch.Size([3])\n",
    "print(\"输出维度:\", output.size(0))  # 3\n",
    "\n",
    "output"
   ],
   "id": "581af6213d09b8b6",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "输出形状: torch.Size([3])\n",
      "输出维度: 3\n"
     ]
    },
    {
     "ename": "IndexError",
     "evalue": "tuple index out of range",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mIndexError\u001B[0m                                Traceback (most recent call last)",
      "Cell \u001B[1;32mIn[15], line 13\u001B[0m\n\u001B[0;32m     11\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m输出形状:\u001B[39m\u001B[38;5;124m\"\u001B[39m, output\u001B[38;5;241m.\u001B[39mshape)  \u001B[38;5;66;03m# torch.Size([3])\u001B[39;00m\n\u001B[0;32m     12\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m输出维度:\u001B[39m\u001B[38;5;124m\"\u001B[39m, output\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m0\u001B[39m])  \u001B[38;5;66;03m# 3\u001B[39;00m\n\u001B[1;32m---> 13\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[43moutput\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mshape\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;241;43m1\u001B[39;49m\u001B[43m]\u001B[49m)\n\u001B[0;32m     14\u001B[0m output\n",
      "\u001B[1;31mIndexError\u001B[0m: tuple index out of range"
     ]
    }
   ],
   "execution_count": 15
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
